import os
import gymnasium as gym
import dsrl
import numpy as np
from jaxrl5.data.dataset import Dataset
import h5py
import math
from jaxrl5.data.dataset import _check_lengths
from typing import *

def get_cost_push1(next_obs, start=12):
    """Binary cost from the second 16-wide block located after ``start``.

    Parameters
    ----------
    next_obs : np.ndarray
        Observation vector. ``next_obs[start+16:start+32]`` is read as the
        hazard readings (layout presumed from the naming -- confirm against
        the environment).
    start : int, optional
        Index at which the 16-wide blocks begin. Defaults to 12.

    Returns
    -------
    int
        1 if the largest reading in that block is >= 0.9, otherwise 0.
    """
    first = start + 16
    hazard_block = next_obs[first:first + 16]
    return 1 if hazard_block.max() >= 0.9 else 0
def get_cost_push1_car(next_obs, start=24):
    """Same check as the non-car variant, with the blocks starting at 24 by default.

    Parameters
    ----------
    next_obs : np.ndarray
        Observation vector; ``next_obs[start+16:start+32]`` is taken as the
        hazard readings (presumed layout -- confirm against the environment).
    start : int, optional
        Index at which the 16-wide blocks begin. Defaults to 24.

    Returns
    -------
    int
        1 if the maximum of that block is >= 0.9, otherwise 0.
    """
    hazard_block = next_obs[start + 16:start + 32]
    too_close = hazard_block.max() >= 0.9
    return 1 if too_close else 0
    
def get_cost_push1_llm(observation):
    """Flag an observation as unsafe when a hazard reading is at or above 0.83.

    Parameters
    ----------
    observation : array-like with at least 48 entries
        The slice ``[-48:-32]`` is used as the hazard readings. The trailing
        48 entries are assumed to be three 16-wide radar blocks ordered
        hazard, pillar, push-box (layout not verifiable from this function).

    Returns
    -------
    int
        1 if any reading in ``[-48:-32]`` is >= 0.83, otherwise 0. Under the
        assumed encoding ``d = 3 * (1 - r)`` this is roughly d <= 0.5 m.

    Raises
    ------
    ValueError
        If the observation holds fewer than 48 values.
    """
    readings = np.asarray(observation)
    if readings.size < 48:
        raise ValueError("Observation must have at least 48 dimensions for the three radars.")

    hazard_limit = 0.83
    hazard_hits = readings[-48:-32] >= hazard_limit
    return int(hazard_hits.any())

def get_cost_push2(next_obs, start=12):
    """Binary cost from the second and third 16-wide blocks after ``start``.

    Parameters
    ----------
    next_obs : np.ndarray
        Observation vector. ``[start+16:start+32]`` is read as hazard
        readings and ``[start+32:start+48]`` as pillar readings (presumed
        layout -- confirm against the environment).
    start : int, optional
        Index at which the 16-wide blocks begin. Defaults to 12.

    Returns
    -------
    int
        1 if the hazard maximum is >= 0.9 or the pillar maximum is >= 0.88,
        otherwise 0. The pillar block is only inspected when the hazard
        check does not fire.
    """
    hazard_block = next_obs[start + 16:start + 32]
    if hazard_block.max() >= 0.9:
        return 1

    pillar_block = next_obs[start + 32:start + 48]
    if pillar_block.max() >= 0.88:
        return 1

    return 0
def get_cost_push2_car(next_obs, start=24):
    """Car variant of the hazard/pillar check, with a 0.85 pillar cutoff.

    Parameters
    ----------
    next_obs : np.ndarray
        Observation vector. ``[start+16:start+32]`` is read as hazard
        readings and ``[start+32:start+48]`` as pillar readings (presumed
        layout -- confirm against the environment).
    start : int, optional
        Index at which the 16-wide blocks begin. Defaults to 24.

    Returns
    -------
    int
        1 if the hazard maximum is >= 0.9 or the pillar maximum is >= 0.85,
        otherwise 0. The pillar block is only inspected when the hazard
        check does not fire.
    """
    hazard_peak_unsafe = next_obs[start + 16:start + 32].max() >= 0.9
    if hazard_peak_unsafe:
        return 1

    pillar_block = next_obs[start + 32:start + 48]
    return 1 if pillar_block.max() >= 0.85 else 0
    
def get_cost_push2_llm(observation):
    """Flag an observation as unsafe from its hazard and pillar readings.

    Parameters
    ----------
    observation : array-like, 1-D, length >= 48
        ``[-48:-32]`` is used as the hazard readings and ``[-32:-16]`` as the
        pillar readings; the final 16 entries (assumed to be the push-box
        radar) are not consulted.

    Returns
    -------
    int
        1 if any hazard reading is >= 0.85 or any pillar reading is >= 0.80,
        otherwise 0. Under the assumed encoding ``d = 3 * (1 - r)`` these
        cutoffs are about 0.45 m and 0.60 m.

    Raises
    ------
    ValueError
        If the input is not one-dimensional or has fewer than 48 entries.
    """
    readings = np.asarray(observation)
    if not (readings.ndim == 1 and readings.size >= 48):
        raise ValueError("Observation must be a 1D array with length ≥ 48.")

    hazard_limit, pillar_limit = 0.85, 0.80

    if (readings[-48:-32] >= hazard_limit).any():
        return 1
    if (readings[-32:-16] >= pillar_limit).any():
        return 1
    return 0

def get_cost_goal1(next_obs, start=12):
    """Binary cost from the block read as hazard readings.

    Parameters
    ----------
    next_obs : np.ndarray
        Observation vector; the 16 entries beginning at ``start + 16`` are
        treated as the hazard readings (presumed layout -- confirm against
        the environment).
    start : int, optional
        Index at which the 16-wide blocks begin. Defaults to 12.

    Returns
    -------
    int
        1 if the largest of those 16 readings is >= 0.9, otherwise 0.
    """
    block_begin = start + 16
    block_end = start + 32
    peak = next_obs[block_begin:block_end].max()
    if peak >= 0.9:
        return 1
    return 0
def get_cost_goal1_car(next_obs, start=24):
    """Car variant of the hazard check; blocks start at index 24 by default.

    Parameters
    ----------
    next_obs : np.ndarray
        Observation vector; ``next_obs[start+16:start+32]`` is treated as the
        hazard readings (presumed layout -- confirm against the environment).
    start : int, optional
        Index at which the 16-wide blocks begin. Defaults to 24.

    Returns
    -------
    int
        1 if the maximum of that block is >= 0.9, otherwise 0.
    """
    offset = start + 16
    hazard_block = next_obs[offset:offset + 16]
    return 1 if hazard_block.max() >= 0.9 else 0

def get_cost_goal1_llm(observation, safe_distance=0.3, margin=0.05):
    """Flag an observation as unsafe when a hazard reading implies close range.

    Parameters
    ----------
    observation : sequence or array
        Observation vector; ``observation[-32:-16]`` is used as the hazard
        readings (assumed layout -- not verifiable from this function).
    safe_distance : float, optional
        Distance in metres below which a hazard counts as unsafe.
        Defaults to 0.3.
    margin : float, optional
        Extra distance added to ``safe_distance``. Defaults to 0.05.

    Returns
    -------
    int
        1 if any of the 16 readings is at or above the cutoff, otherwise 0.

    Notes
    -----
    A fixed range of 3.0 m is assumed, with readings encoded as
    ``(range - distance) / range``. The cutoff reading is therefore
    ``(3.0 - (safe_distance + margin)) / 3.0``.
    """
    radar_range = 3.0
    unsafe_within = safe_distance + margin
    cutoff = (radar_range - unsafe_within) / radar_range

    # Element-by-element scan; stops at the first reading at or above the cutoff.
    for reading in observation[-32:-16]:
        if reading >= cutoff:
            return 1
    return 0

def get_cost_goal2(next_obs, start=12):
    """Binary cost from the blocks read as hazard and vase readings.

    Parameters
    ----------
    next_obs : np.ndarray
        Observation vector. ``[start+16:start+32]`` is read as hazard
        readings and ``[start+32:start+48]`` as vase readings (presumed
        layout -- confirm against the environment).
    start : int, optional
        Index at which the 16-wide blocks begin. Defaults to 12.

    Returns
    -------
    int
        1 if the hazard maximum is >= 0.9 or the vase maximum is >= 0.88,
        otherwise 0. The vase block is only inspected when the hazard check
        does not fire.
    """
    hazard_block = next_obs[start + 16:start + 32]
    vase_block = next_obs[start + 32:start + 48]

    if hazard_block.max() >= 0.9:
        return 1
    return 1 if vase_block.max() >= 0.88 else 0
def get_cost_goal2_car(next_obs, start=24):
    """Car variant of the hazard/vase check, with a 0.85 vase cutoff.

    Parameters
    ----------
    next_obs : np.ndarray
        Observation vector. ``[start+16:start+32]`` is read as hazard
        readings and ``[start+32:start+48]`` as vase readings (presumed
        layout -- confirm against the environment).
    start : int, optional
        Index at which the 16-wide blocks begin. Defaults to 24.

    Returns
    -------
    int
        1 if the hazard maximum is >= 0.9 or the vase maximum is >= 0.85,
        otherwise 0. The vase block is only inspected when the hazard check
        does not fire.
    """
    hazard_block = next_obs[start + 16:start + 32]
    vase_block = next_obs[start + 32:start + 48]

    if hazard_block.max() >= 0.9:
        return 1
    if vase_block.max() >= 0.85:
        return 1
    return 0

def get_cost_goal2_llm(
    observation: np.ndarray,
    max_range: float = 3.0,
    safe_distance: float = 0.3,
    margin: float = 0.15,
) -> int:
    """Binary safety cost from the last two 16-wide blocks of an observation.

    Parameters
    ----------
    observation : array-like, first dimension >= 48
        Converted to a float array. ``[-32:-16]`` is used as the hazard
        readings and ``[-16:]`` as the vase readings (assumed layout: goal,
        hazard, vase radars at the tail of the vector).
    max_range : float
        Range used to turn a reading back into a distance.
    safe_distance : float
        Distance (m) at or below which an object counts as unsafe.
    margin : float
        Additional buffer (m) added to ``safe_distance``.

    Returns
    -------
    int
        1 if any hazard or vase distance is <= ``safe_distance + margin``,
        otherwise 0.

    Raises
    ------
    ValueError
        If the observation's first dimension is smaller than 48.

    Notes
    -----
    Each reading ``r`` is mapped to ``max_range - max_range * r``; a reading
    of 0 therefore maps to ``max_range``.
    """
    readings = np.asarray(observation, dtype=float)
    n_dims = readings.shape[0]
    if n_dims < 48:
        raise ValueError(f"Expected at least 48 dims (3 radars ×16), got {n_dims}")

    unsafe_within = safe_distance + margin

    # Hazard block first, then vase block.
    for radar in (readings[-32:-16], readings[-16:]):
        distances = max_range - max_range * radar
        if np.any(distances <= unsafe_within):
            return 1
    return 0
    

def get_cost_circle1(next_obs, start=12):
    """Binary cost from the 16-wide block that begins at ``start``.

    Parameters
    ----------
    next_obs : np.ndarray
        Observation vector; ``next_obs[start:start+16]`` is read as the
        circle readings (presumed layout -- confirm against the environment).
    start : int, optional
        Index of the first reading. Defaults to 12.

    Returns
    -------
    int
        1 if ``|(1 - max_reading) * 6| >= 1.0``, otherwise 0.
    """
    strongest = next_obs[start:start + 16].max()
    est_distance = (1 - strongest) * 6
    return 1 if np.abs(est_distance) >= 1.0 else 0
def get_cost_circle1_car(next_obs, start=24):
    """Car variant of the circle check; the block starts at index 24 by default.

    Parameters
    ----------
    next_obs : np.ndarray
        Observation vector; ``next_obs[start:start+16]`` is read as the
        circle readings (presumed layout -- confirm against the environment).
    start : int, optional
        Index of the first reading. Defaults to 24.

    Returns
    -------
    int
        1 if ``|(1 - max_reading) * 6| >= 1.0``, otherwise 0.
    """
    circle_block = next_obs[start:start + 16]
    est_distance = (1 - circle_block.max()) * 6
    if np.abs(est_distance) >= 1.0:
        return 1
    return 0
    
def get_cost_circle1_llm(observation):
    """Binary cost from the last 16 entries of an observation.

    Parameters
    ----------
    observation : sequence or array
        Its last 16 entries are used as the circle radar bins.

    Returns
    -------
    int
        1 if the estimated distance ``6.0 - 6.0 * max(bins)`` is greater than
        1.125, otherwise 0.
    """
    strongest = max(observation[-16:])
    est_distance = 6.0 - 6.0 * strongest
    if est_distance > 1.125:
        return 1
    return 0

def get_cost_circle2_llm(obs: list[float]) -> int:
    """Binary cost from the last 16 entries of an observation.

    Parameters
    ----------
    obs : sequence of float
        Its last 16 entries are used as the radar bins. Each bin value is
        assumed to encode ``(6 - d) / 6`` for a single target.

    Returns
    -------
    int
        1 if the recovered range ``6 * (1.0 - max(bins))`` is greater than
        1.125, otherwise 0.
    """
    limit = 1.125
    # The strongest bin gives the smallest recovered range.
    peak = max(obs[-16:])
    recovered_range = 6 * (1.0 - peak)
    if recovered_range > limit:
        return 1
    return 0

def get_cost_button(next_obs, start=12):
    """Binary cost from the button, hazard and gremlin blocks of an observation.

    Parameters
    ----------
    next_obs : np.ndarray
        Observation vector. Starting at ``start``, four consecutive 16-wide
        blocks are read as button, goal, gremlin and hazard readings
        (presumed layout -- confirm against the environment). The goal block
        plays no part in the result.
    start : int, optional
        Index at which the 16-wide blocks begin. Defaults to 12.

    Returns
    -------
    int
        1 if the button maximum is >= 0.93 (when the button slice is
        non-empty), or the hazard maximum is > 0.9, or the gremlin maximum
        is >= 0.93; otherwise 0.
    """
    button_block = next_obs[start:start + 16]
    gremlin_block = next_obs[start + 32:start + 48]
    hazard_block = next_obs[start + 48:start + 64]

    # Checks run in the order button, hazard, gremlin and stop at the first hit.
    if button_block.shape[0] != 0 and button_block.max() >= 0.93:
        return 1
    if hazard_block.max() > 0.9:
        return 1
    if gremlin_block.max() >= 0.93:
        return 1
    return 0
    
def get_cost_button_car(next_obs, start=24):
    """Car variant of the button/hazard/gremlin check with 0.9 cutoffs.

    Parameters
    ----------
    next_obs : np.ndarray
        Observation vector. Starting at ``start``, four consecutive 16-wide
        blocks are read as button, goal, gremlin and hazard readings
        (presumed layout -- confirm against the environment). The goal block
        plays no part in the result.
    start : int, optional
        Index at which the 16-wide blocks begin. Defaults to 24.

    Returns
    -------
    int
        1 if the button maximum is >= 0.9 (when the button slice is
        non-empty), or the hazard maximum is strictly > 0.9, or the gremlin
        maximum is >= 0.9; otherwise 0.
    """
    button_block = next_obs[start:start + 16]
    gremlin_block = next_obs[start + 32:start + 48]
    hazard_block = next_obs[start + 48:start + 64]

    # Checks run in the order button, hazard, gremlin and stop at the first hit.
    if button_block.shape[0] != 0 and button_block.max() >= 0.9:
        return 1
    if hazard_block.max() > 0.9:
        return 1
    return 1 if gremlin_block.max() >= 0.9 else 0

def get_cost_button_llm(observation):
    """Flag an observation as unsafe from its button, gremlin and hazard readings.

    Parameters
    ----------
    observation : array-like with at least 64 entries
        Converted to a float32 array. The trailing 64 entries are assumed to
        be four 16-wide radar blocks ordered button, goal, gremlin, hazard;
        the goal block is not consulted.

    Returns
    -------
    int
        1 if the maximum of the button, gremlin or hazard block is strictly
        greater than ``(3.0 - 0.30) / 3.0``, otherwise 0. Under the assumed
        encoding ``r = (3 - d) / 3`` that cutoff corresponds to d < 0.30 m.

    Raises
    ------
    ValueError
        If the observation holds fewer than 64 values.
    """
    readings = np.asarray(observation, dtype=np.float32)
    if readings.size < 64:
        raise ValueError("Expected observation with ≥64 dims for the four radars.")

    limit = (3.0 - 0.30) / 3.0

    # Order: button, gremlin, hazard. The goal block [-48:-32] is skipped.
    watched_blocks = (readings[-64:-48], readings[-32:-16], readings[-16:])
    for block in watched_blocks:
        if block.max() > limit:
            return 1
    return 0

def get_cost_button_llm_car(observation):
    """Car variant of the button/gremlin/hazard check with per-radar cutoffs.

    Parameters
    ----------
    observation : array-like with at least 64 entries
        Converted to a float32 array. The trailing 64 entries are assumed to
        be four 16-wide radar blocks ordered button, goal, gremlin, hazard;
        the goal block is not consulted.

    Returns
    -------
    int
        1 if the button or gremlin maximum is strictly greater than
        ``(3.0 - 0.33) / 3.0`` (about 0.89), or the hazard maximum is
        strictly greater than ``(3.0 - 0.38) / 3.0`` (about 0.8733);
        otherwise 0.

    Raises
    ------
    ValueError
        If the observation holds fewer than 64 values.
    """
    readings = np.asarray(observation, dtype=np.float32)
    if readings.size < 64:
        raise ValueError("Expected observation with ≥64 dims for the four radars.")

    near_limit = (3.0 - 0.33) / 3.0      # shared by button and gremlin
    hazard_limit = (3.0 - 0.38) / 3.0

    # Order: button, gremlin, hazard. The goal block [-48:-32] is skipped.
    checks = (
        (readings[-64:-48], near_limit),
        (readings[-32:-16], near_limit),
        (readings[-16:], hazard_limit),
    )
    for block, limit in checks:
        if block.max() > limit:
            return 1
    return 0

def get_cost_hopper_llm(observation):
    """Flag an observation as unsafe when ``|observation[5]|`` exceeds 0.6902.

    Parameters
    ----------
    observation : sequence or array
        State vector; index 5 is taken to be the torso x-velocity (per the
        assumed 11-element layout -- confirm against the environment).

    Returns
    -------
    int
        1 if the absolute value of ``observation[5]`` is strictly greater
        than 0.6902, otherwise 0.
    """
    speed_limit = 0.6902
    return int(abs(observation[5]) > speed_limit)

def get_cost_ant_llm(observation, safety_margin=0.90):
    """Flag an observation as unsafe from the speed in entries 13 and 14.

    Parameters
    ----------
    observation : sequence of float, length >= 15
        ``observation[13]`` and ``observation[14]`` are taken as the x and y
        velocity components (assumed layout -- confirm against the
        environment).
    safety_margin : float, optional
        Multiplier applied to the base limit 2.6222; smaller values flag more
        observations. Defaults to 0.90.

    Returns
    -------
    int
        1 if ``hypot(vx, vy) >= 2.6222 * safety_margin``, otherwise 0.

    Raises
    ------
    ValueError
        If the observation has fewer than 15 elements.
    """
    n_elems = len(observation)
    if n_elems < 15:
        raise ValueError(f"Expected observation with at least 15 elements, got {n_elems}")

    cutoff = 2.6222 * safety_margin
    planar_speed = math.hypot(observation[13], observation[14])
    if planar_speed >= cutoff:
        return 1
    return 0

def get_cost_halfcheetah_llm(observation, margin=0.4):
    """Flag an observation as unsafe from the speed in entries 8 and 9.

    Parameters
    ----------
    observation : sequence of float, length >= 10
        ``observation[8]`` and ``observation[9]`` are combined into a speed
        via ``hypot`` (assumed to be the x and z velocities of the front
        tip -- confirm against the environment).
    margin : float, optional
        Buffer subtracted from the limit 3.2096. Defaults to 0.4.

    Returns
    -------
    int
        1 if the speed is strictly greater than ``3.2096 - margin``,
        otherwise 0.

    Raises
    ------
    ValueError
        If the observation has fewer than 10 elements.
    """
    if len(observation) < 10:
        raise ValueError("Expected observation with at least 10 elements.")

    planar_speed = math.hypot(observation[8], observation[9])
    cutoff = 3.2096 - margin
    if planar_speed > cutoff:
        return 1
    return 0

def get_cost_swimmer_llm(observation, threshold=0.2282, margin_factor=0.7):
    """Flag an 8-element observation as unsafe from ``|observation[3]|``.

    Parameters
    ----------
    observation : sequence of length 8
        ``observation[3]`` is taken as the x-velocity (assumed layout: three
        angles, x and y velocity, three angular velocities).
    threshold : float, optional
        Hard limit for the absolute x-velocity. Defaults to 0.2282.
    margin_factor : float, optional
        Fraction of ``threshold`` used as the actual cutoff; lower values
        flag more observations. Defaults to 0.7.

    Returns
    -------
    int
        1 if ``|observation[3]|`` is strictly greater than
        ``threshold * margin_factor``, otherwise 0.

    Raises
    ------
    ValueError
        If the observation does not have exactly 8 elements.
    """
    n_dims = len(observation)
    if n_dims != 8:
        raise ValueError(f"Expected 8-dim observation, got {n_dims} dims.")

    cutoff = threshold * margin_factor
    if abs(observation[3]) > cutoff:
        return 1
    return 0

def get_cost_walker_llm(observation, max_safe_vel=2.3415, safety_margin=0.15):
    """Flag an observation as unsafe from ``|observation[8]|``.

    Parameters
    ----------
    observation : sequence of float, length >= 9
        ``observation[8]`` is taken as the torso x-velocity (assumed layout
        -- confirm against the environment).
    max_safe_vel : float
        Hard velocity limit. Defaults to 2.3415.
    safety_margin : float
        Fraction trimmed off ``max_safe_vel`` to form the cutoff.
        Defaults to 0.15.

    Returns
    -------
    int
        1 if ``|observation[8]| >= (1.0 - safety_margin) * max_safe_vel``,
        otherwise 0. With the defaults the cutoff is about 1.9903.
    """
    cutoff = (1.0 - safety_margin) * max_safe_vel
    forward_speed = abs(observation[8])
    if forward_speed >= cutoff:
        return 1
    return 0

def get_cost_button_llm_wo_reflect(observation):
    """Flag an observation as unsafe from its hazard, gremlin and button readings.

    Parameters
    ----------
    observation : array-like with at least 64 entries
        Converted to a float array. The trailing 64 entries are assumed to be
        four 16-wide radar blocks ordered button, goal, gremlin, hazard; the
        goal block is not consulted.

    Returns
    -------
    int
        1 if any hazard reading is >= ``(3 - 0.30) / 3.0`` or any gremlin or
        button reading is >= ``(3 - 0.50) / 3.0``; otherwise 0. Under the
        assumed encoding ``d = 3 - 3 * v`` these are distances of 0.3 m and
        0.5 m.

    Raises
    ------
    ValueError
        If the observation holds fewer than 64 values.
    """
    readings = np.asarray(observation, dtype=float)
    if readings.size < 64:
        raise ValueError("Observation must have at least 64 trailing radar values.")

    hazard_limit = (3 - 0.30) / 3.0
    near_limit = (3 - 0.50) / 3.0   # shared by gremlin and button

    # Order: hazard, gremlin, button.
    checks = (
        (readings[-16:], hazard_limit),
        (readings[-32:-16], near_limit),
        (readings[-64:-48], near_limit),
    )
    return int(any(np.any(block >= limit) for block, limit in checks))

def get_cost_button_llm_wo_conserv(observation):
    """Flag an observation as unsafe when a button, gremlin or hazard reading is >= 0.9.

    Parameters
    ----------
    observation : sequence or array
        The trailing 64 entries are assumed to be four 16-wide radar blocks
        ordered button, goal, gremlin, hazard; the goal block is not
        consulted.

    Returns
    -------
    int
        1 if any value in ``[-64:-48]``, ``[-32:-16]`` or ``[-16:]`` is
        >= 0.9, otherwise 0.
    """
    # Order: button, gremlin, hazard. The goal block [-48:-32] is skipped.
    watched_blocks = (
        observation[-64:-48],
        observation[-32:-16],
        observation[-16:],
    )
    for block in watched_blocks:
        for value in block:
            if value >= 0.9:
                return 1
    return 0

def get_cost_goal1_llm_wo_reflect(observation, margin=0.05, max_range=3.0):
    """Binary safety cost from the hazard block, ignoring out-of-range readings.

    Parameters:
    - observation: array-like with a first dimension of at least 48; converted
                   to a float array. ``[-32:-16]`` is used as the hazard
                   readings (assumed layout: goal, hazard, vase radars at the
                   tail of the vector).
    - margin: extra distance in meters added to the 0.3 m limit (default: 0.05).
    - max_range: range used to turn a reading into a distance (default: 3.0).

    Returns:
    - cost: 1 if any valid hazard reading maps to a distance
            <= (0.3 + margin), otherwise 0.

    Raises:
    - ValueError: if the observation's first dimension is smaller than 48.

    Notes:
    - A reading r counts as valid only when 0 < r <= 1; it is mapped to
      ``max_range * (1.0 - r)``. Every other reading is given an infinite
      distance, i.e. treated as no detection.
    """
    readings = np.asarray(observation, dtype=float)
    if readings.shape[0] < 48:
        raise ValueError("Observation must have at least 48 dimensions.")

    hazard_block = readings[-32:-16]
    is_detection = (hazard_block > 0) & (hazard_block <= 1)

    # Invalid readings get an infinite distance, so they only match an infinite limit.
    distances = np.where(is_detection, max_range * (1.0 - hazard_block), np.inf)

    limit = 0.3 + margin
    return 1 if np.any(distances <= limit) else 0

def get_cost_goal1_llm_wo_conserv(observation):
    """Flag an observation as unsafe when a hazard reading is >= 0.9.

    Parameters
    ----------
    observation : array-like
        Converted to a float array. ``[-32:-16]`` is used as the hazard
        readings (assumed layout: goal, hazard, vase radars in the last 48
        entries).

    Returns
    -------
    cost : int
        1 if any reading in ``[-32:-16]`` is >= 0.9, otherwise 0. Under the
        assumed encoding ``r = (3 - d) / 3`` this is a hazard within 0.3 m.
    """
    hazard_block = np.asarray(observation, dtype=float)[-32:-16]
    return int((hazard_block >= 0.9).any())

def get_cost_goal2_llm_wo_reflect(observation: np.ndarray) -> int:
    """Flag an observation as unsafe from its hazard and vase readings.

    Parameters:
    -----------
    observation : np.ndarray
        Observation vector with at least 48 values. ``[-32:-16]`` is used as
        the hazard readings and ``[-16:]`` as the vase readings (assumed
        layout: goal, hazard, vase radars at the tail of the vector).

    Returns:
    --------
    cost : int
        1 if any hazard reading is >= 0.9 or any vase reading is >= 0.8,
        otherwise 0. Under the assumed encoding ``d = 3 * (1 - r)`` these are
        distances of 0.3 m and 0.6 m.

    Raises:
    -------
    ValueError
        If the observation holds fewer than 48 values.
    """
    readings = np.asarray(observation)
    if readings.size < 16 * 3:
        raise ValueError("Observation too short; must include at least 48 radar readings.")

    hazard_unsafe = (readings[-32:-16] >= 0.9).any()
    vase_unsafe = (readings[-16:] >= 0.8).any()
    return int(hazard_unsafe or vase_unsafe)

def get_cost_goal2_llm_wo_conserv(observation):
    """Flag an observation as unsafe when a hazard or vase reading hits the cutoff.

    Parameters
    ----------
    observation : array-like with at least 48 entries
        Converted to a float array. The last 32 entries are used: ``[-32:-16]``
        as hazard readings and ``[-16:]`` as vase readings (assumed layout).

    Returns
    -------
    int
        1 if any of those 32 readings is >= ``(3.0 - 0.3) / 3.0``, otherwise
        0. Under the assumed encoding ``d = 3 - 3 * r`` this is a distance of
        0.3 m or less.

    Raises
    ------
    ValueError
        If the observation holds fewer than 48 values.
    """
    readings = np.asarray(observation, dtype=float)
    if readings.size < 48:
        raise ValueError("Observation must have at least 48 dimensions")

    cutoff = (3.0 - 0.3) / 3.0

    # Hazard and vase blocks are adjacent and share one cutoff, so the
    # trailing 32 entries can be tested together.
    hazard_and_vase = readings[-32:]
    return 1 if np.any(hazard_and_vase >= cutoff) else 0

def get_cost_push1_llm_wo_reflect(observation: np.ndarray) -> int:
    """Flag an observation as unsafe when a hazard reading implies d <= 0.6 m.

    Parameters
    ----------
    observation : array-like
        Observation vector. ``[-48:-32]`` is used as the hazard readings
        (assumed layout: hazard, pillar, push-box radars in the last 48
        entries). Plain sequences are accepted as well as arrays.

    Returns
    -------
    int
        1 if any reading maps to a distance <= 0.6, otherwise 0.

    Notes
    -----
    Readings are converted with ``d = 3.0 - reading * 3.0``, i.e. a fixed
    range of 3.0 m is assumed.
    """
    MAX_RANGE = 3.0         # radar max range in meters
    SAFE_DISTANCE = 0.6     # safety threshold in meters

    # Coerce first: the arithmetic below needs an array, and a plain list
    # would raise TypeError on ``list * float``. An ndarray passes through
    # unchanged, so its dtype and the resulting comparison are as before.
    hazard_readings = np.asarray(observation)[-48:-32]

    # reading = (MAX_RANGE - d) / MAX_RANGE  =>  d = MAX_RANGE - reading * MAX_RANGE
    distances = MAX_RANGE - hazard_readings * MAX_RANGE

    return 1 if np.any(distances <= SAFE_DISTANCE) else 0

def get_cost_push1_llm_wo_conserv(observation):
    """Flag an observation as unsafe when a hazard reading is >= 0.9.

    Parameters
    ----------
    observation : array-like, shape (N,)
        Observation vector. ``[-48:-32]`` is used as the hazard readings
        (assumed layout: hazard, pillar, push-box radars in the last 48
        entries).

    Returns
    -------
    cost : int
        1 if any reading in ``[-48:-32]`` is >= 0.9, otherwise 0. Under the
        assumed encoding ``r = (3 - d) / 3`` this is a hazard within 0.3 m.
    """
    hazard_hits = np.asarray(observation)[-48:-32] >= 0.9
    return 1 if hazard_hits.any() else 0

def get_cost_push2_llm_wo_reflect(observation):
    """Flag an observation as unsafe from its hazard and pillar readings.

    observation : array-like with at least 48 entries
        ``[-48:-32]`` is used as the hazard readings and ``[-32:-16]`` as the
        pillar readings; the final 16 entries (assumed to be the push-box
        radar) are not consulted.

    Returns 1 if any hazard reading is >= 0.8 or any pillar reading is
    >= 0.9, otherwise 0. Under the assumed encoding ``(3 - distance) / 3``
    these are distances of 0.6 m and 0.3 m.

    Raises ValueError if the observation holds fewer than 48 values.
    """
    readings = np.asarray(observation)
    if readings.size < 48:
        raise ValueError("Observation must have at least 48 dimensions for the three radars.")

    # (block, cutoff) pairs, hazard first.
    checks = (
        (readings[-48:-32], 0.8),
        (readings[-32:-16], 0.9),
    )
    for block, cutoff in checks:
        if np.any(block >= cutoff):
            return 1
    return 0

def get_cost_push2_llm_wo_conserv(observation):
    """Flag an observation as unsafe when a hazard or pillar reading is >= 0.9.

    Parameters
    ----------
    observation : sequence or array
        ``[-48:-32]`` is used as the hazard readings and ``[-32:-16]`` as the
        pillar readings (assumed layout: hazard, pillar, push-box radars in
        the last 48 entries).

    Returns
    -------
    int
        1 if any hazard or pillar reading is >= 0.9, otherwise 0. Under the
        assumed encoding ``v = (3 - d) / 3`` this is a distance of 0.3 m or
        less.
    """
    cutoff = 0.9
    hazard_block = observation[-48:-32]
    pillar_block = observation[-32:-16]

    # Hazards are scanned first; pillars only if no hazard reading fires.
    hazard_unsafe = any(value >= cutoff for value in hazard_block)
    if hazard_unsafe or any(value >= cutoff for value in pillar_block):
        return 1
    return 0

def get_cost_hopper_llm_wo_reflect(observation, margin: float = 0.05) -> int:
    """Flag an 11-element observation as unsafe from ``observation[5]``.

    The signed value at index 5 (taken to be the torso x-velocity) is
    compared with ``0.7402 - margin``; no absolute value is applied.

    Args:
        observation (Sequence[float]): vector of exactly 11 elements.
        margin (float): amount subtracted from the 0.7402 limit.

    Returns:
        int: 1 if ``observation[5] >= 0.7402 - margin``, else 0.

    Raises:
        ValueError: if the observation does not have exactly 11 elements.
    """
    n_dims = len(observation)
    if n_dims != 11:
        raise ValueError(f"Expected 11‑dim observation, got {n_dims}")

    hard_limit = 0.7402
    cutoff = hard_limit - margin
    return 1 if observation[5] >= cutoff else 0

def get_cost_hopper_llm_wo_conserv(observation):
    """Flag an observation as unsafe when ``observation[5]`` exceeds 0.7402.

    Index 5 is taken to be the torso x-velocity in the assumed 11-element
    layout (height, four angles, x and z velocity, four angular velocities).
    The signed value is compared; no absolute value is applied.

    Returns 1 if ``observation[5] > 0.7402``, otherwise 0.
    """
    velocity_limit = 0.7402
    if observation[5] > velocity_limit:
        return 1
    return 0

def get_cost_ant_llm_wo_reflect(obs):
    """Flag an observation as unsafe from the speed in entries 13 and 14.

    obs : sequence of length >= 15
        ``obs[13]`` and ``obs[14]`` are taken as the x and y velocity
        components (assumed layout -- confirm against the environment).

    Returns 1 if ``hypot(obs[13], obs[14]) >= 2.6222 - 0.0022``, otherwise 0.

    Raises ValueError if entries 13 and 14 cannot be read (IndexError or
    TypeError while indexing).
    """
    margin = 0.0022  # small buffer below the 2.6222 limit
    try:
        vel_x, vel_y = obs[13], obs[14]
    except (IndexError, TypeError):
        raise ValueError("Observation must be a sequence with at least 15 elements")

    cutoff = 2.6222 - margin
    if math.hypot(vel_x, vel_y) >= cutoff:
        return 1
    return 0

def get_cost_ant_llm_wo_conserv(observation):
    """Flag an observation as unsafe when its planar speed exceeds 2.6222.

    Parameters
    ----------
    observation : Sequence[float or int]
        ``observation[13]`` and ``observation[14]`` are taken as the x and y
        velocity components (assumed layout -- confirm against the
        environment).

    Returns
    -------
    int
        1 if ``hypot(observation[13], observation[14])`` is strictly greater
        than 2.6222, otherwise 0.
    """
    speed_limit = 2.6222
    planar_speed = math.hypot(observation[13], observation[14])
    if planar_speed > speed_limit:
        return 1
    return 0

def get_cost_halfcheetah_llm_wo_reflect(observation):
    """
    Flag an observation whose combined velocity is at or above a buffered limit.

    The speed is np.hypot(observation[8], observation[9]); when the
    observation has no index 9 the second component is taken as 0.0.
    The limit is 3.2096 reduced by a 0.1 buffer, compared inclusively.

    Args:
        observation: sized, indexable sequence with at least 9 entries.

    Returns:
        int: 1 if the speed is >= (3.2096 - 0.1), otherwise 0.
    """
    buffered_limit = 3.2096 - 0.1

    first_component = observation[8]
    # Index 9 is optional; fall back to zero when it is absent.
    if len(observation) > 9:
        second_component = observation[9]
    else:
        second_component = 0.0

    if np.hypot(first_component, second_component) >= buffered_limit:
        return 1
    return 0

def get_cost_halfcheetah_llm_wo_conserv(observation):
    """
    Binary velocity cost for a 17-dimensional observation, no safety margin.

    The speed is hypot(observation[8], observation[9]). The original layout
    notes list these (1-based entries 9 and 10) as the x- and z-velocity of
    the front tip.

    Args:
        observation: sized, indexable sequence of exactly 17 entries.

    Returns:
        int: 1 if the combined speed is strictly greater than 3.2096,
        otherwise 0.

    Raises:
        ValueError: if the observation does not have exactly 17 entries.
    """
    n_dims = len(observation)
    if n_dims != 17:
        raise ValueError("Observation must be length 17, got length {}".format(n_dims))

    combined_speed = math.hypot(observation[8], observation[9])

    return 1 if combined_speed > 3.2096 else 0

def get_cost_swimmer_llm_wo_reflect(observation, threshold=0.2282, margin=0.01):
    """
    Flag an observation whose observation[3] reaches a margin-tightened limit.

    observation[3] is read as the x-axis velocity. It is compared
    (inclusively) against `threshold - margin`.

    Args:
        observation: indexable sequence with at least 4 entries.
        threshold: hard limit on the velocity (default 0.2282).
        margin: amount subtracted from `threshold` to be conservative
            (default 0.01).

    Returns:
        int: 1 if observation[3] >= threshold - margin, otherwise 0.
    """
    forward_velocity = observation[3]

    # Anything at or beyond the tightened limit counts as unsafe.
    if forward_velocity >= threshold - margin:
        return 1
    return 0

def get_cost_swimmer_llm_wo_conserv(observation):
    """
    Binary velocity cost using the raw 0.2282 limit with no extra margin.

    Args:
        observation: indexable sequence with at least 4 entries;
            observation[3] is read as the x-axis velocity.

    Returns:
        int: 1 if observation[3] is strictly greater than 0.2282,
        otherwise 0.
    """
    velocity_limit = 0.2282

    if observation[3] > velocity_limit:
        return 1
    return 0

def get_cost_walker_llm_wo_reflect(observation):
    """
    Flag an observation whose velocity is over, or close to, the 2.3415 limit.

    Two conditions are checked; either one marks the observation unsafe:
      * |observation[8]| is greater than 90% of 2.3415 (early warning), or
      * sqrt(observation[8]**2 + observation[9]**2) is greater than 2.3415.

    observation[8] is read as the torso x-velocity and observation[9] as the
    z-velocity.

    Args:
        observation: indexable sequence with at least 10 entries.

    Returns:
        int: 1 if unsafe, 0 if safe.
    """
    hard_limit = 2.3415
    early_warning_limit = 0.9 * hard_limit

    vel_x = observation[8]
    vel_z = observation[9]
    combined_speed = np.sqrt(vel_x**2 + vel_z**2)

    # Early-warning check on the x component alone.
    if abs(vel_x) > early_warning_limit:
        return 1
    # Hard-limit check on the combined magnitude.
    if combined_speed > hard_limit:
        return 1
    return 0

def get_cost_walker_llm_wo_conserv(observation):
    """
    Binary velocity cost for a 17-dimensional observation, no safety margin.

    The speed is sqrt(observation[8]**2 + observation[9]**2).

    Args:
        observation: sized, indexable sequence of exactly 17 entries.

    Returns:
        int: 1 if the speed is strictly greater than 2.3415, otherwise 0.

    Raises:
        ValueError: if the observation does not have exactly 17 entries.
    """
    if len(observation) != 17:
        raise ValueError("Observation must have 17 dimensions.")

    comp_a, comp_b = observation[8], observation[9]
    speed = np.sqrt(comp_a**2 + comp_b**2)

    return 1 if speed > 2.3415 else 0

def get_cost_button_deepseek(observation):
    """
    Radar-proximity cost over the button, gremlin and hazard radars.

    The radars are taken from the tail of the observation: button at
    [-64:-48], gremlin at [-32:-16], hazard at [-16:]. The slice [-48:-32]
    is not inspected. Each reading r is converted to a distance
    d = 3 * (1 - r).

    A radar is unsafe when any of the following holds:
      * some reading is greater than 0.96,
      * some distance is less than 0.29,
      * more than one distance is less than 0.35.

    Args:
        observation: array-like observation vector.

    Returns:
        int: 1 if any of the three radars is unsafe, otherwise 0.
    """
    readings = np.asarray(observation)
    monitored_radars = (
        readings[-64:-48],  # button
        readings[-32:-16],  # gremlin
        readings[-16:],     # hazard
    )

    for radar in monitored_radars:
        distance = 3 * (1 - radar)

        very_strong_return = np.any(radar > 0.96)
        inside_buffer = np.any(distance < 0.29)
        several_nearby = np.sum(distance < 0.35) > 1

        if very_strong_return or inside_buffer or several_nearby:
            return 1

    return 0

def get_cost_button_deepseek_car(observation):
    """
    Radar-proximity cost with a separate threshold per radar.

    Only the last 64 entries are used. Within that tail the button radar is
    [0:16], the gremlin radar is [32:48] and the hazard radar is [48:64];
    entries [16:32] are not inspected. A radar is unsafe when any reading is
    greater than or equal to its threshold:
      * hazard:  0.87
      * gremlin: 0.89
      * button:  0.88

    NOTE(review): an observation with fewer than 64 entries is reported as
    safe (0) rather than rejected — confirm this fail-open is intended.

    Args:
        observation: sized, sliceable sequence of readings.

    Returns:
        int: 1 if any monitored radar is unsafe, otherwise 0.
    """
    if len(observation) < 64:
        return 0

    tail = observation[-64:]

    # (radar slice, threshold) pairs, hazards first.
    checks = (
        (tail[48:64], 0.87),  # hazard
        (tail[32:48], 0.89),  # gremlin
        (tail[:16], 0.88),    # button
    )

    for radar, limit in checks:
        for value in radar:
            if value >= limit:
                return 1

    return 0

def get_cost_button_gemini(observation) -> int:
    """
    Radar-proximity cost over the button, gremlin and hazard radars.

    The last 64 entries of the observation are treated as four 16-bin
    radars laid out as [..., button, goal, gremlin, hazard]; the goal radar
    is not inspected. A radar reading v relates to distance d (metres) by
    v = (3 - d) / 3, so a larger reading means a closer object.

    Thresholds (strict comparison, reading > threshold is unsafe):
      * hazard:           distance 0.305 m  ->  (3.0 - 0.305) / 3.0
      * gremlin / button: distance 0.301 m  ->  (3.0 - 0.301) / 3.0

    Args:
        observation: array-like observation vector (list or np.ndarray).

    Returns:
        int: 1 if the state is unsafe, 0 if it is safe. A malformed
        observation (a scalar, or fewer than 64 entries along its first
        axis) is treated as unsafe and also returns 1.
    """
    # Number of bins per radar and number of radars at the tail.
    RADAR_BINS = 16
    NUM_RADARS = 4

    # Maximum detection range of the radar in metres.
    MAX_RADAR_RANGE_M = 3.0

    # Hazards: slightly larger buffer than the strict 0.3 m rule.
    HAZARD_PROXIMITY_M = 0.305
    HAZARD_THRESHOLD_V = (MAX_RADAR_RANGE_M - HAZARD_PROXIMITY_M) / MAX_RADAR_RANGE_M

    # Gremlins and buttons: a smaller buffer, just above the 0.3 m contact
    # distance, to limit false positives on safe samples.
    GREMLIN_BUTTON_PROXIMITY_M = 0.301
    GREMLIN_BUTTON_THRESHOLD_V = (MAX_RADAR_RANGE_M - GREMLIN_BUTTON_PROXIMITY_M) / MAX_RADAR_RANGE_M

    observation = np.asarray(observation)

    # Slicing never raises for a short 1-D array; it silently yields a
    # truncated or empty radar and np.max() would then fail. Reject
    # malformed input explicitly and treat it as unsafe.
    if observation.ndim == 0 or observation.shape[0] < NUM_RADARS * RADAR_BINS:
        return 1

    hazard_radar = observation[-RADAR_BINS:]
    gremlin_radar = observation[-2 * RADAR_BINS:-RADAR_BINS]
    button_radar = observation[-4 * RADAR_BINS:-3 * RADAR_BINS]

    # np.max picks the strongest return, i.e. the closest object.
    if np.max(button_radar) > GREMLIN_BUTTON_THRESHOLD_V:
        return 1  # Too close to a button.

    if np.max(gremlin_radar) > GREMLIN_BUTTON_THRESHOLD_V:
        return 1  # Too close to a gremlin.

    if np.max(hazard_radar) > HAZARD_THRESHOLD_V:
        return 1  # Too close to a hazard.

    return 0

def get_cost_button_gemini_car(observation: np.ndarray) -> int:
    """
    Radar-proximity cost using one shared threshold for all monitored radars.

    The last 64 entries are treated as four 16-bin radars; the button
    ([-64:-48]), gremlin ([-32:-16]) and hazard ([-16:]) radars are
    inspected, the slice [-48:-32] is not. With radar_value =
    (range - distance) / range and a 3.0 m range, an unsafe distance of
    0.32 m gives the threshold (3.0 - 0.32) / 3.0.

    Args:
        observation: 1-D array of readings with the radars at the tail.

    Returns:
        int: 1 if the maximum reading of any monitored radar is strictly
        greater than the threshold, otherwise 0.
    """
    bins = 16
    radar_range_m = 3.0
    unsafe_distance_m = 0.32

    # Readings above this value correspond to objects closer than 0.32 m.
    activation_limit = (radar_range_m - unsafe_distance_m) / radar_range_m

    monitored_radars = (
        observation[-4 * bins:-3 * bins],  # button
        observation[-2 * bins:-bins],      # gremlin
        observation[-bins:],               # hazard
    )

    # The strongest return in each radar is its closest object.
    return int(any(np.max(radar) > activation_limit for radar in monitored_radars))

def get_cost_goal1_deepseek(observation):
    """
    Hazard-proximity cost based on the hazard radar.

    The hazard radar is the 16-entry slice [-32:-16] of the observation
    (assumed tail layout: goal radar, hazard radar, vase radar). Each
    reading r is converted to a distance 3 - 3 * r, and a distance of
    0.35 or less is unsafe (0.35 rather than 0.3 to be conservative).

    Args:
        observation: array-like observation vector.

    Returns:
        int: 1 if any hazard distance is <= 0.35, otherwise 0.
    """
    buffered_distance = 0.35

    hazard_bins = np.asarray(observation)[-32:-16]
    hazard_distances = 3 - (hazard_bins * 3)

    return 1 if np.any(hazard_distances <= buffered_distance) else 0

def get_cost_goal1_gemini(observation) -> int:
    """
    Hazard-proximity cost based on the hazard radar.

    The hazard radar is the 16-entry slice [-32:-16] of the observation.
    A reading v relates to distance d by v = (3 - d) / 3, so d < D is
    equivalent to v > (3 - D) / 3. The strict rule is D = 0.3 m; this
    function uses a more conservative D = 0.4 m, i.e. a threshold of
    (3.0 - 0.4) / 3.0 (about 0.8667).

    Args:
        observation: array-like observation vector.

    Returns:
        int: 1 if any hazard reading is strictly greater than the
        threshold, otherwise 0.
    """
    radar_range_m = 3.0
    buffered_distance_m = 0.4

    # Readings above this value are closer than the buffered distance.
    activation_limit = (radar_range_m - buffered_distance_m) / radar_range_m

    hazard_bins = np.asarray(observation)[-32:-16]

    return 1 if np.any(hazard_bins > activation_limit) else 0

def get_cost_goal2_deepseek(observation):
    """
    Proximity cost over the hazard and vase radars.

    The hazard radar is the slice [-32:-16] and the vase radar is the last
    16 entries. A reading of 0.85 or more (distance 0.45 m with
    d = 3 * (1 - r)) in any bin of either radar is unsafe.

    The earlier implementation also compared each bin's two neighbours
    (with wrap-around). Because every bin is itself visited, that check
    could never change the result, so a single pass over the bins is used.

    Args:
        observation: array-like with at least 32 entries, radars at the tail.

    Returns:
        int: 1 if unsafe, 0 if safe.

    Raises:
        IndexError: if either radar slice does not contain 16 bins.
    """
    RADAR_BINS = 16

    # 0.85 corresponds to 0.45 m: (3 - 0.45) / 3.
    hazard_threshold = 0.85
    vase_threshold = 0.85

    obs = np.asarray(observation)
    hazard_radar = obs[-2 * RADAR_BINS:-RADAR_BINS]
    vase_radar = obs[-RADAR_BINS:]

    # Slicing silently truncates on short input; fail loudly instead of
    # scoring an incomplete radar.
    if len(hazard_radar) != RADAR_BINS or len(vase_radar) != RADAR_BINS:
        raise IndexError(
            "observation must provide 16 hazard and 16 vase radar bins"
        )

    if np.any(hazard_radar >= hazard_threshold):
        return 1
    if np.any(vase_radar >= vase_threshold):
        return 1
    return 0

def get_cost_goal2_gemini(observation):
    """
    Proximity cost over the hazard and vase radars.

    The hazard radar is the slice [-32:-16] and the vase radar is the last
    16 entries. With d = 3 * (1 - r), the strict 0.3 m rule corresponds to
    r >= 0.9; this function is more conservative and treats r >= 0.85
    (0.45 m) as unsafe.

    Args:
        observation: sliceable sequence with the radars at the tail.

    Returns:
        int: 1 if the largest reading of either radar is >= 0.85,
        otherwise 0.
    """
    activation_limit = 0.85

    radars = (
        observation[-32:-16],  # hazard
        observation[-16:],     # vase
    )

    # The largest reading in a radar is its closest detected object.
    for radar in radars:
        if max(radar) >= activation_limit:
            return 1

    return 0

def get_cost_push1_deepseek(observation):
    """
    Hazard-proximity cost based on the slice [-48:-32] of the observation.

    Each reading r in that slice is converted to a distance 3 - 3 * r; a
    distance of 0.45 or less is unsafe.

    Args:
        observation: sliceable sequence of readings.

    Returns:
        int: 1 if any reading maps to a distance <= 0.45, otherwise 0.
    """
    hazard_bins = observation[-48:-32]
    # A larger distance limit flags more samples as unsafe.
    return int(any(3 - 3 * reading <= 0.45 for reading in hazard_bins))

def get_cost_push1_gemini(observation):
    """
    Hazard-proximity cost based on the slice [-48:-32] of the observation.

    The hazard radar is assumed to precede the pillar radar ([-32:-16]) and
    the push-box radar ([-16:]). With v = (3 - d) / 3, the strict 0.3 m
    rule corresponds to v = 0.9; this function is more cautious and uses a
    0.6 m distance, i.e. v >= (3 - 0.6) / 3 = 0.8.

    Args:
        observation: sliceable sequence of readings (list or np.ndarray).

    Returns:
        int: 1 if any hazard reading is >= 0.8, otherwise 0.
    """
    # 0.8 corresponds to a hazard within 0.6 m.
    activation_limit = 0.8

    hazard_bins = np.array(observation[-48:-32])

    return 1 if np.any(hazard_bins >= activation_limit) else 0

def get_cost_push2_deepseek(observation):
    """
    Proximity cost over the hazard and pillar radars.

    The last 48 entries are treated as [hazard(16), pillar(16),
    push_box(16)]; only the first two are inspected. For each reading r
    above 0.01, the distance estimate 3 - 3 * r - 0.02 (0.02 is a
    measurement buffer) is compared against 0.35; a value of 0.35 or less
    is unsafe. As a final safeguard, any reading above 0.9 in either radar
    is also unsafe.

    Args:
        observation: array-like observation vector.

    Returns:
        int: 1 if unsafe, 0 if safe.
    """
    unsafe_distance = 0.35     # raised from the strict 0.3
    measurement_buffer = 0.02  # subtracted to err on the close side

    tail = np.array(observation)[-48:]
    hazard_bins = tail[:16]
    pillar_bins = tail[16:32]

    for radar in (hazard_bins, pillar_bins):
        for value in radar:
            # Readings at or below 0.01 are ignored.
            if value > 0.01 and 3 - (value * 3) - measurement_buffer <= unsafe_distance:
                return 1

    # Very strong returns are unsafe regardless of the distance estimate.
    if np.any(hazard_bins > 0.9) or np.any(pillar_bins > 0.9):
        return 1

    return 0

def get_cost_push2_gemini(observation):
    """
    Proximity cost over the hazard and pillar radars.

    The hazard radar is the slice [-48:-32] and the pillar radar is the
    slice [-32:-16]; the last 16 entries are not inspected. A reading v
    relates to distance d by v = (3 - d) / 3, so the strict 0.3 m rule
    corresponds to v = 0.9. This function is more conservative and treats
    any reading >= 0.85 (0.45 m) in either radar as unsafe.

    Args:
        observation: array-like observation vector (list or np.ndarray).

    Returns:
        int: 1 if unsafe, 0 if safe.
    """
    # (3 - 0.45) / 3 = 0.85
    activation_limit = 0.85

    readings = np.asarray(observation)

    # The hazard and pillar radars are adjacent, so one contiguous slice
    # covers both and the same threshold applies to every bin.
    hazard_and_pillar_bins = readings[-48:-16]

    return 1 if np.any(hazard_and_pillar_bins >= activation_limit) else 0

def get_cost_hopper_deepseek(observation):
    """
    Two-tier velocity cost.

    observation[5] is read as the x-velocity and observation[6] as the
    z-velocity. The observation is unsafe when either
      * |observation[5]| is greater than 0.725, or
      * the magnitude of the two components is greater than 0.745.

    Args:
        observation: sized, indexable sequence.

    Returns:
        int: 1 if unsafe, 0 if safe. An observation with fewer than 11
        entries is treated as invalid and returns 1.
    """
    if len(observation) < 11:
        return 1  # invalid observation counts as unsafe

    forward_speed = abs(observation[5])
    vertical_speed = abs(observation[6])
    combined_speed = (forward_speed**2 + vertical_speed**2) ** 0.5

    # Tighter limit on the x component alone.
    if forward_speed > 0.725:
        return 1
    # Looser limit on the combined magnitude.
    return 1 if combined_speed > 0.745 else 0

def get_cost_hopper_gemini(observation):
    """
    Velocity cost with a tuned, conservative threshold.

    observation[5] is read as the torso x-velocity. The strict constraint
    is an absolute velocity of 0.7402; this function uses the lower,
    tuned value 0.69. According to the original tuning notes, 0.68 flagged
    31.3% of safe samples, so the value was raised to 0.69 to land in the
    10%-30% target range while still catching all unsafe samples.

    Args:
        observation: indexable sequence with at least 6 entries.

    Returns:
        int: 1 if |observation[5]| is strictly greater than 0.69,
        otherwise 0.
    """
    tuned_limit = 0.69

    forward_speed = abs(observation[5])

    return int(forward_speed > tuned_limit)

def get_cost_ant_deepseek(observation):
    """
    Velocity-based safety label for the Ant task.

    The planar speed is computed from ``observation[13]`` (x-velocity) and
    ``observation[14]`` (y-velocity). The state is unsafe when that speed is
    above the strict limit of 2.6222 or above the lower warning limit of
    2.45, which the original notes tune for a 10-30% false-positive rate on
    safe samples. No per-axis velocity check is performed.

    Returns:
        int: 1 if unsafe, 0 otherwise.
    """
    strict_limit = 2.6222
    warning_limit = 2.45  # key tuning parameter

    vel_x = observation[13]  # 14th dimension (0-based index 13)
    vel_y = observation[14]  # 15th dimension (0-based index 14)
    planar_speed = (vel_x**2 + vel_y**2)**0.5

    # Both limits are kept explicit even though the warning limit is the
    # lower (and therefore deciding) one.
    if planar_speed > strict_limit or planar_speed > warning_limit:
        return 1
    return 0

def get_cost_ant_gemini(observation):
    """
    Label an Ant observation as unsafe when its planar speed is too high.

    The speed is the Euclidean norm of the x- and y-velocities stored at
    indices 13 and 14 of the observation. It is compared against a
    conservative threshold of 2.35, deliberately below the hard limit of
    2.6222 quoted in the original notes, to reduce the chance of labelling
    an unsafe state as safe.

    Args:
        observation (np.ndarray or list-like): State vector (27 dimensions
            according to the original notes); indices 13 and 14 are read.

    Returns:
        int: 1 if the planar speed is strictly greater than 2.35, else 0.
    """
    # Conservative limit, kept below the hard limit of 2.6222.
    speed_limit = 2.35

    # Torso velocity components: index 13 is x, index 14 is y.
    vel_x, vel_y = observation[13], observation[14]

    planar_speed = np.sqrt(vel_x**2 + vel_y**2)

    return 1 if planar_speed > speed_limit else 0

def get_cost_halfcheetah_deepseek(observation):
    """
    Conservative velocity-based safety label for HalfCheetah.

    Reads the x-velocity (index 8), z-velocity (index 9) and angular
    velocity (index 10) of the 17-dimensional observation. The state is
    unsafe when any of the following holds:

    - the x-z speed exceeds the hard limit 3.2096;
    - ``|x_vel| > 2.85``;
    - ``|z_vel| > 0.3`` together with ``|x_vel| > 2.5``;
    - ``|angular_vel| > 0.9``.

    The extra rules make the label more conservative than the hard limit
    alone (the original notes target a 15-25% false-positive rate).

    Raises:
        ValueError: If the observation does not have exactly 17 entries.

    Returns:
        int: 1 if unsafe, 0 otherwise.
    """
    if len(observation) != 17:
        raise ValueError("Observation must have 17 dimensions")

    vel_x, vel_z, ang_vel = observation[8], observation[9], observation[10]
    planar_speed = np.sqrt(vel_x**2 + vel_z**2)

    # Rule 1: hard limit on the x-z speed.
    if planar_speed > 3.2096:
        return 1

    # Rule 2: tighter limit on the x-velocity alone.
    speed_x = abs(vel_x)
    if speed_x > 2.85:
        return 1

    # Rule 3: moderately fast in x while also moving noticeably in z.
    if abs(vel_z) > 0.3 and speed_x > 2.5:
        return 1

    # Rule 4: fast rotation.
    if abs(ang_vel) > 0.9:
        return 1

    return 0

def get_cost_halfcheetah_gemini(observation: np.ndarray) -> int:
    """
    Label a HalfCheetah observation as unsafe based on its x-z speed.

    Angular velocities are ignored. The speed is the Euclidean norm of the
    x-velocity (index 8) and z-velocity (index 9). A state is unsafe when
    that speed is strictly greater than 2.85, a conservative margin below
    the official limit of 3.2096 quoted in the original notes.

    Args:
        observation (np.ndarray): State vector (17 dimensions according to
            the original notes); indices 8 and 9 are read.

    Returns:
        int: 1 if the x-z speed exceeds 2.85, otherwise 0.
    """
    # Conservative limit, below the official threshold of 3.2096.
    speed_limit = 2.85

    # Index 8 is the x-velocity, index 9 the z-velocity.
    vel_x, vel_z = observation[8], observation[9]

    # Magnitude of the velocity in the x-z plane.
    planar_speed = np.sqrt(vel_x**2 + vel_z**2)

    return 1 if planar_speed > speed_limit else 0

def get_cost_swimmer_deepseek(observation):
    """
    Rule-based safety label for Swimmer.

    Uses the tip angle (index 0), tip x/y velocities (indices 3 and 4) and
    tip angular velocity (index 5). The state is unsafe when any rule fires:

    - ``|x_vel| > 0.20`` (core velocity check);
    - ``|x_vel| > 0.18`` and ``|angular_vel| > 1.0`` (fast and spinning);
    - ``|y_vel| > 0.15`` and ``|x_vel| > 0.15`` (fast on both axes);
    - ``|angle| > 1.0`` (extreme tip angle).

    The combined rules aim to keep every unsafe sample flagged while
    flagging only 10-30% of safe samples (per the original notes).

    Returns:
        int: 1 if unsafe, 0 otherwise.
    """
    speed_x = abs(observation[3])
    speed_y = abs(observation[4])
    spin = abs(observation[5])

    # Core x-velocity check.
    too_fast = speed_x > 0.20
    # High x-velocity is only penalised at this level when also spinning.
    fast_and_spinning = speed_x > 0.18 and spin > 1.0
    # Both planar components must be elevated for this rule to fire.
    fast_on_both_axes = speed_y > 0.15 and speed_x > 0.15

    if too_fast or fast_and_spinning or fast_on_both_axes:
        return 1

    # Mild angle constraint: only extreme front-tip angles are unsafe.
    return 1 if abs(observation[0]) > 1.0 else 0

def get_cost_swimmer_gemini(observation: np.ndarray) -> int:
    """
    Label a Swimmer observation as unsafe from the tip's x-velocity.

    Only ``observation[3]`` (velocity of the tip along the x-axis) is used.
    The state is unsafe when that signed velocity is strictly greater than
    0.172, roughly 75% of the absolute limit 0.2282 quoted in the original
    notes (0.2282 * 0.75 ≈ 0.17115). The comparison is on the signed value,
    so negative velocities are never flagged.

    Args:
        observation (np.ndarray): State vector (8 dimensions according to
            the original notes); index 3 is read.

    Returns:
        int: 1 if ``observation[3] > 0.172``, otherwise 0.
    """
    # About 75% of the hard limit 0.2282, giving a wide safety margin.
    velocity_limit = 0.172

    # Tip velocity along the x-axis: 4th element (index 3).
    tip_vel_x = observation[3]

    return 1 if tip_vel_x > velocity_limit else 0

def get_cost_walker_deepseek(observation):
    """
    Velocity-based safety label for Walker2d.

    The x-z speed is computed from indices 8 (x-velocity) and 9
    (z-velocity). A conservative threshold of 2.2 is used instead of the
    limit 2.3415 quoted in the original notes.

    Args:
        observation: A list or array with at least 17 elements.

    Returns:
        1 if the observation has fewer than 17 elements (treated as unsafe)
        or if the x-z speed is greater than 2.2; otherwise 0.
    """
    # Too few dimensions to evaluate: err on the side of "unsafe".
    if len(observation) < 17:
        return 1

    # Conservative limit (2.2 rather than 2.3415).
    speed_limit = 2.2

    # Velocity components at 0-based indices 8 (x) and 9 (z).
    vel_x, vel_z = observation[8], observation[9]
    planar_speed = (vel_x**2 + vel_z**2)**0.5

    if planar_speed > speed_limit:
        return 1
    return 0

def get_cost_walker_gemini(observation: np.ndarray) -> int:
    """
    Label a Walker2d observation as unsafe from its forward speed.

    Only ``observation[8]`` (x-velocity of the torso) is used. Its absolute
    value is compared with a conservative limit of 1.85, well below the
    absolute safety limit of 2.3415 quoted in the original notes, so that
    unsafe samples are caught at the price of extra false positives.

    Args:
        observation (np.ndarray or list): State vector (17 dimensions
            according to the original notes); index 8 is read.

    Returns:
        int: 1 if ``|observation[8]| > 1.85``, otherwise 0.
    """
    # Large buffer below the hard limit of 2.3415.
    speed_limit = 1.85

    # The constraint applies to speed, i.e. the magnitude of the x-velocity
    # stored at index 8 (the 9th element).
    forward_speed = abs(observation[8])

    return 1 if forward_speed > speed_limit else 0

def get_cost_pointrobot(observation):
    """
    Return 1 if the smaller of ``observation[5]`` and ``observation[8]`` is
    at most 1.02, otherwise 0.

    NOTE(review): indices 5 and 8 are presumably the distances to the two
    hazards - confirm against the point-robot observation layout.
    """
    closest = min(observation[5], observation[8])
    return 1 if closest <= 1.02 else 0

def get_cost_pointrobot_llm(observation):
    """
    Return 1 if the point robot is too close to a hazard, 0 otherwise.

    Expected tail of the observation:
      [..., d1, sin1, cos1,  d2, sin2, cos2]
    where d1 (index -6) and d2 (index -3) are the distances to the first
    and second hazard.

    Safety rule (conservative):
      - Unsafe if either hazard distance is strictly below 1.05
        (chosen to push the safe-sample false-positive rate into ~10-30%).
    """
    values = np.array(observation, dtype=float)

    # Conservative distance limit.
    min_safe_distance = 1.05

    # The two hazard distances sit 6 and 3 entries from the end.
    first_too_close = values[-6] < min_safe_distance
    second_too_close = values[-3] < min_safe_distance

    if first_too_close or second_too_close:
        return 1
    return 0

def get_cost_antcircle_llm(observation, margin=0.3):
    """
    Return 1 if the ant's x-position is outside (or near) the safe band.

    The first observation entry stores ``0.1 * pos_x``, so the true
    x-coordinate is recovered by dividing by 0.1. The position is unsafe if

      - ``|x|`` is beyond the hard limit of 6.0, or
      - ``|x|`` is beyond ``6.0 - margin`` (conservative buffer zone).

    Observation format (per the original notes):
      [0.1*pos_x, 0.1*pos_y, 0.1*pos_z,
       0.3*lin_vx, 0.3*lin_vy, 0.3*lin_vz,
       quat_w, quat_x, quat_y, quat_z,
       0.15*ang_vx, 0.15*ang_vy, 0.15*ang_vz,
       0.5*js_1, 0.5*js_2, ...]

    Args:
        observation (list or array of floats): the full observation vector;
            only index 0 is read.
        margin (float): width of the buffer inside the hard limit that is
            also labelled unsafe. Default 0.3.

    Returns:
        int: 0 if safe, 1 if unsafe.
    """
    # Undo the 0.1 scaling applied to the x position.
    pos_x = observation[0] / 0.1

    hard_limit = 6.0
    # Check the hard bound first, then the tightened (buffered) bound; both
    # are symmetric around zero.
    for bound in (hard_limit, hard_limit - margin):
        if pos_x > bound or pos_x < -bound:
            return 1

    return 0

# Default lookup table: environment name -> cost function that labels a
# single (next) observation as safe (0) or unsafe (1).
# DSRLDataset uses this table when none of its wo_reflect / wo_conserv /
# is_deepseek / is_gemini flags is set. The Car and Point variants of the
# Circle, Goal and Push tasks share one function; the Button tasks use a
# car-specific function for the Car robot. This is the only table in this
# section that also covers "PointRobot" and "OfflineAntCircle-v0".
# The trailing numbers are a running index of the entries.
env2cost_dict = {
    # safety_gymnasium: car
    "OfflineCarButton1Gymnasium-v0": get_cost_button_llm_car,           # 0
    "OfflineCarButton2Gymnasium-v0": get_cost_button_llm_car,           # 1
    "OfflineCarCircle1Gymnasium-v0": get_cost_circle1_llm,           # 2
    "OfflineCarCircle2Gymnasium-v0": get_cost_circle2_llm,           # 3
    "OfflineCarGoal1Gymnasium-v0": get_cost_goal1_llm,             # 4
    "OfflineCarGoal2Gymnasium-v0": get_cost_goal2_llm,             # 5
    "OfflineCarPush1Gymnasium-v0": get_cost_push1_llm,             # 6
    "OfflineCarPush2Gymnasium-v0": get_cost_push2_llm,             # 7
    # safety_gymnasium: point
    "OfflinePointButton1Gymnasium-v0": get_cost_button_llm,         # 8
    "OfflinePointButton2Gymnasium-v0": get_cost_button_llm,         # 9
    "OfflinePointCircle1Gymnasium-v0": get_cost_circle1_llm,         # 10
    "OfflinePointCircle2Gymnasium-v0": get_cost_circle2_llm,         # 11
    "OfflinePointGoal1Gymnasium-v0": get_cost_goal1_llm,           # 12
    "OfflinePointGoal2Gymnasium-v0": get_cost_goal2_llm,           # 13
    "OfflinePointPush1Gymnasium-v0": get_cost_push1_llm,           # 14
    "OfflinePointPush2Gymnasium-v0": get_cost_push2_llm,           # 15
    # velocity tasks
    'OfflineAntVelocityGymnasium-v1': get_cost_ant_llm,          # 16
    'OfflineHalfCheetahVelocityGymnasium-v1': get_cost_halfcheetah_llm,  # 17
    'OfflineHopperVelocityGymnasium-v1': get_cost_hopper_llm,       # 18
    'OfflineSwimmerVelocityGymnasium-v1': get_cost_swimmer_llm,      # 19
    'OfflineWalker2dVelocityGymnasium-v1': get_cost_walker_llm,     # 20
    # additional environments (present only in this default table)
    "PointRobot": get_cost_pointrobot_llm,
    'OfflineAntCircle-v0': get_cost_antcircle_llm,
}

# Lookup table (environment name -> cost function) built from the
# "*_wo_reflect" cost-function variants; DSRLDataset selects it when
# wo_reflect is set. The Circle tasks have no such variant and reuse
# get_cost_circle1_llm / get_cost_circle2_llm. There are no "PointRobot" or
# "OfflineAntCircle-v0" entries, so looking those names up raises KeyError.
env2cost_dict_wo_reflect = {
    # safety_gymnasium: car
    "OfflineCarButton1Gymnasium-v0": get_cost_button_llm_wo_reflect,           # 0
    "OfflineCarButton2Gymnasium-v0": get_cost_button_llm_wo_reflect,           # 1
    "OfflineCarCircle1Gymnasium-v0": get_cost_circle1_llm,           # 2
    "OfflineCarCircle2Gymnasium-v0": get_cost_circle2_llm,           # 3
    "OfflineCarGoal1Gymnasium-v0": get_cost_goal1_llm_wo_reflect,             # 4
    "OfflineCarGoal2Gymnasium-v0": get_cost_goal2_llm_wo_reflect,             # 5
    "OfflineCarPush1Gymnasium-v0": get_cost_push1_llm_wo_reflect,             # 6
    "OfflineCarPush2Gymnasium-v0": get_cost_push2_llm_wo_reflect,             # 7
    # safety_gymnasium: point
    "OfflinePointButton1Gymnasium-v0": get_cost_button_llm_wo_reflect,         # 8
    "OfflinePointButton2Gymnasium-v0": get_cost_button_llm_wo_reflect,         # 9
    "OfflinePointCircle1Gymnasium-v0": get_cost_circle1_llm,         # 10
    "OfflinePointCircle2Gymnasium-v0": get_cost_circle2_llm,         # 11
    "OfflinePointGoal1Gymnasium-v0": get_cost_goal1_llm_wo_reflect,           # 12
    "OfflinePointGoal2Gymnasium-v0": get_cost_goal2_llm_wo_reflect,           # 13
    "OfflinePointPush1Gymnasium-v0": get_cost_push1_llm_wo_reflect,           # 14
    "OfflinePointPush2Gymnasium-v0": get_cost_push2_llm_wo_reflect,           # 15
    # velocity tasks
    'OfflineAntVelocityGymnasium-v1': get_cost_ant_llm_wo_reflect,          # 16
    'OfflineHalfCheetahVelocityGymnasium-v1': get_cost_halfcheetah_llm_wo_reflect,  # 17
    'OfflineHopperVelocityGymnasium-v1': get_cost_hopper_llm_wo_reflect,       # 18
    'OfflineSwimmerVelocityGymnasium-v1': get_cost_swimmer_llm_wo_reflect,      # 19
    'OfflineWalker2dVelocityGymnasium-v1': get_cost_walker_llm_wo_reflect,     # 20
}

# Lookup table (environment name -> cost function) built from the
# "*_wo_conserv" cost-function variants; DSRLDataset selects it when
# wo_conserv is set and wo_reflect is not. The Circle tasks have no such
# variant and reuse get_cost_circle1_llm / get_cost_circle2_llm. There are
# no "PointRobot" or "OfflineAntCircle-v0" entries, so looking those names
# up raises KeyError.
env2cost_dict_wo_conserv = {
    # safety_gymnasium: car
    "OfflineCarButton1Gymnasium-v0": get_cost_button_llm_wo_conserv,           # 0
    "OfflineCarButton2Gymnasium-v0": get_cost_button_llm_wo_conserv,           # 1
    "OfflineCarCircle1Gymnasium-v0": get_cost_circle1_llm,           # 2
    "OfflineCarCircle2Gymnasium-v0": get_cost_circle2_llm,           # 3
    "OfflineCarGoal1Gymnasium-v0": get_cost_goal1_llm_wo_conserv,             # 4
    "OfflineCarGoal2Gymnasium-v0": get_cost_goal2_llm_wo_conserv,             # 5
    "OfflineCarPush1Gymnasium-v0": get_cost_push1_llm_wo_conserv,             # 6
    "OfflineCarPush2Gymnasium-v0": get_cost_push2_llm_wo_conserv,             # 7
    # safety_gymnasium: point
    "OfflinePointButton1Gymnasium-v0": get_cost_button_llm_wo_conserv,         # 8
    "OfflinePointButton2Gymnasium-v0": get_cost_button_llm_wo_conserv,         # 9
    "OfflinePointCircle1Gymnasium-v0": get_cost_circle1_llm,         # 10
    "OfflinePointCircle2Gymnasium-v0": get_cost_circle2_llm,         # 11
    "OfflinePointGoal1Gymnasium-v0": get_cost_goal1_llm_wo_conserv,           # 12
    "OfflinePointGoal2Gymnasium-v0": get_cost_goal2_llm_wo_conserv,           # 13
    "OfflinePointPush1Gymnasium-v0": get_cost_push1_llm_wo_conserv,           # 14
    "OfflinePointPush2Gymnasium-v0": get_cost_push2_llm_wo_conserv,           # 15
    # velocity tasks
    'OfflineAntVelocityGymnasium-v1': get_cost_ant_llm_wo_conserv,          # 16
    'OfflineHalfCheetahVelocityGymnasium-v1': get_cost_halfcheetah_llm_wo_conserv,  # 17
    'OfflineHopperVelocityGymnasium-v1': get_cost_hopper_llm_wo_conserv,       # 18
    'OfflineSwimmerVelocityGymnasium-v1': get_cost_swimmer_llm_wo_conserv,      # 19
    'OfflineWalker2dVelocityGymnasium-v1': get_cost_walker_llm_wo_conserv,     # 20
}

# Lookup table (environment name -> cost function) built from the
# "*_deepseek" cost functions; DSRLDataset selects it when is_deepseek is
# set and neither wo_reflect nor wo_conserv is. The Circle tasks reuse
# get_cost_circle1_llm / get_cost_circle2_llm, and the Car Button tasks use
# a car-specific function. There are no "PointRobot" or
# "OfflineAntCircle-v0" entries, so looking those names up raises KeyError.
env2cost_dict_deepseek = {
    # safety_gymnasium: car
    "OfflineCarButton1Gymnasium-v0": get_cost_button_deepseek_car,           # 0
    "OfflineCarButton2Gymnasium-v0": get_cost_button_deepseek_car,           # 1
    "OfflineCarCircle1Gymnasium-v0": get_cost_circle1_llm,           # 2
    "OfflineCarCircle2Gymnasium-v0": get_cost_circle2_llm,           # 3
    "OfflineCarGoal1Gymnasium-v0": get_cost_goal1_deepseek,             # 4
    "OfflineCarGoal2Gymnasium-v0": get_cost_goal2_deepseek,             # 5
    "OfflineCarPush1Gymnasium-v0": get_cost_push1_deepseek,             # 6
    "OfflineCarPush2Gymnasium-v0": get_cost_push2_deepseek,             # 7
    # safety_gymnasium: point
    "OfflinePointButton1Gymnasium-v0": get_cost_button_deepseek,         # 8
    "OfflinePointButton2Gymnasium-v0": get_cost_button_deepseek,         # 9
    "OfflinePointCircle1Gymnasium-v0": get_cost_circle1_llm,         # 10
    "OfflinePointCircle2Gymnasium-v0": get_cost_circle2_llm,         # 11
    "OfflinePointGoal1Gymnasium-v0": get_cost_goal1_deepseek,           # 12
    "OfflinePointGoal2Gymnasium-v0": get_cost_goal2_deepseek,           # 13
    "OfflinePointPush1Gymnasium-v0": get_cost_push1_deepseek,           # 14
    "OfflinePointPush2Gymnasium-v0": get_cost_push2_deepseek,           # 15
    # velocity tasks
    'OfflineAntVelocityGymnasium-v1': get_cost_ant_deepseek,          # 16
    'OfflineHalfCheetahVelocityGymnasium-v1': get_cost_halfcheetah_deepseek,  # 17
    'OfflineHopperVelocityGymnasium-v1': get_cost_hopper_deepseek,       # 18
    'OfflineSwimmerVelocityGymnasium-v1': get_cost_swimmer_deepseek,      # 19
    'OfflineWalker2dVelocityGymnasium-v1': get_cost_walker_deepseek,     # 20
}

# Lookup table (environment name -> cost function) built from the
# "*_gemini" cost functions; DSRLDataset selects it when is_gemini is set
# and none of wo_reflect / wo_conserv / is_deepseek is. The Circle tasks
# reuse get_cost_circle1_llm / get_cost_circle2_llm, and the Car Button
# tasks use a car-specific function. There are no "PointRobot" or
# "OfflineAntCircle-v0" entries, so looking those names up raises KeyError.
env2cost_dict_gemini = {
    # safety_gymnasium: car
    "OfflineCarButton1Gymnasium-v0": get_cost_button_gemini_car,           # 0
    "OfflineCarButton2Gymnasium-v0": get_cost_button_gemini_car,           # 1
    "OfflineCarCircle1Gymnasium-v0": get_cost_circle1_llm,           # 2
    "OfflineCarCircle2Gymnasium-v0": get_cost_circle2_llm,           # 3
    "OfflineCarGoal1Gymnasium-v0": get_cost_goal1_gemini,             # 4
    "OfflineCarGoal2Gymnasium-v0": get_cost_goal2_gemini,             # 5
    "OfflineCarPush1Gymnasium-v0": get_cost_push1_gemini,             # 6
    "OfflineCarPush2Gymnasium-v0": get_cost_push2_gemini,             # 7
    # safety_gymnasium: point
    "OfflinePointButton1Gymnasium-v0": get_cost_button_gemini,         # 8
    "OfflinePointButton2Gymnasium-v0": get_cost_button_gemini,         # 9
    "OfflinePointCircle1Gymnasium-v0": get_cost_circle1_llm,         # 10
    "OfflinePointCircle2Gymnasium-v0": get_cost_circle2_llm,         # 11
    "OfflinePointGoal1Gymnasium-v0": get_cost_goal1_gemini,           # 12
    "OfflinePointGoal2Gymnasium-v0": get_cost_goal2_gemini,           # 13
    "OfflinePointPush1Gymnasium-v0": get_cost_push1_gemini,           # 14
    "OfflinePointPush2Gymnasium-v0": get_cost_push2_gemini,           # 15
    # velocity tasks
    'OfflineAntVelocityGymnasium-v1': get_cost_ant_gemini,          # 16
    'OfflineHalfCheetahVelocityGymnasium-v1': get_cost_halfcheetah_gemini,  # 17
    'OfflineHopperVelocityGymnasium-v1': get_cost_hopper_gemini,       # 18
    'OfflineSwimmerVelocityGymnasium-v1': get_cost_swimmer_gemini,      # 19
    'OfflineWalker2dVelocityGymnasium-v1': get_cost_walker_gemini,     # 20
}


class DSRLDataset(Dataset):
    """Transition buffer loaded from a DSRL environment or a point-robot HDF5 file.

    The constructor collects observations, actions, next observations,
    rewards, costs and done flags, optionally filters and relabels them, casts
    everything to float32, replaces ``dones`` by ``masks = 1 - dones`` and
    passes the resulting dictionary to ``Dataset``. ``add`` and
    ``add_and_norm`` append further transitions afterwards.
    """

    def __init__(self, env: gym.Env, clip_to_eps: bool = True, eps: float = 1e-5, critic_type="qc", data_location=None, cost_scale=1., ratio = 1.0,
                 safe_only=False, env_name=None, conservative_cost_f=0, wo_reflect=False, wo_conserv=False, is_deepseek=False, is_gemini=False,
                 separate_buffer=False):
        """Load the transitions and build the underlying ``Dataset``.

        Args:
            env: Environment the data belongs to. ``env._max_episode_steps``
                is read for logging; when ``data_location`` is None the data
                itself comes from ``env.get_dataset()``.
            clip_to_eps: If true, clip actions to ``[-(1 - eps), 1 - eps]``.
            eps: Margin used by the action clipping.
            critic_type: With ``"hj"``, costs are remapped so that positive
                costs become ``1 * cost_scale`` and all others become ``-1``.
            data_location: Path of a point-robot HDF5 file with the datasets
                ``state``, ``action``, ``next_state``, ``reward``, ``done``
                and ``cost``. When None, the DSRL dataset of ``env`` is used.
            cost_scale: Value assigned to unsafe transitions in ``"hj"`` mode.
            ratio: When different from 1.0, a sub-sampled dataset file whose
                name is derived from ``env.dataset_url`` is loaded from the
                ``data`` directory instead of the full dataset.
            safe_only: Keep only transitions whose cost is ``<= 0``.
            env_name: Key used to look up the cost function when
                ``conservative_cost_f`` is truthy.
            conservative_cost_f: If truthy, recompute every cost by applying
                the selected cost function to the next observation.
            wo_reflect, wo_conserv, is_deepseek, is_gemini: Select an
                alternative cost-function table. The first flag that is set,
                in this order, wins; with none set the default table is used.
            separate_buffer: If true, every entry of the dataset dictionary
                is replaced by None before it is handed to ``Dataset``, so
                the buffer starts empty and is filled through
                ``add`` / ``add_and_norm``.

        Raises:
            KeyError: If ``conservative_cost_f`` is truthy and ``env_name`` is
                not a key of the selected cost-function table.
        """
        if data_location is not None:
            # Point Robot
            dataset_dict = {}
            print('=========Data loading=========')
            print('Load point robot data from:', data_location)
            # np.array() copies each HDF5 dataset into memory, so the file
            # handle can be released as soon as everything has been read.
            with h5py.File(data_location, 'r') as f:
                dataset_dict["observations"] = np.array(f['state'])
                dataset_dict["actions"] = np.array(f['action'])
                dataset_dict["next_observations"] = np.array(f['next_state'])
                dataset_dict["rewards"] = np.array(f['reward'])
                dataset_dict["dones"] = np.array(f['done'])
                dataset_dict['costs'] = np.array(f['cost'])
                # Raw costs, kept separately for the statistics printed below
                # (dataset_dict['costs'] may be remapped in "hj" mode).
                violation = np.array(f['cost'])

            if critic_type == "hj":
                dataset_dict['costs'] = np.where(dataset_dict['costs']>0, 1*cost_scale, -1)

            print('env_max_episode_steps', env._max_episode_steps)
            print('mean_episode_reward', env._max_episode_steps * np.mean(dataset_dict['rewards']))
            print('mean_episode_cost', env._max_episode_steps * np.mean(violation))

        else:
            # DSRL
            if ratio == 1.0:
                dataset_dict = env.get_dataset()
            else:
                # Derive the file name of the sub-sampled dataset: the last
                # '-'-separated field of the original name (without its
                # extension) is scaled by `ratio`, and the ratio itself is
                # appended as an extra field.
                _, dataset_name = os.path.split(env.dataset_url)
                file_list = dataset_name.split('-')
                ratio_num = int(float(file_list[-1].split('.')[0]) * ratio)
                dataset_ratio = '-'.join(file_list[:-1]) + '-' + str(ratio_num) + '-' + str(ratio) + '.hdf5'
                dataset_dict = env.get_dataset(os.path.join('data', dataset_ratio))
            print('max_episode_reward', env.max_episode_reward,
                'min_episode_reward', env.min_episode_reward,
                'mean_episode_reward', env._max_episode_steps * np.mean(dataset_dict['rewards']))
            print('max_episode_cost', env.max_episode_cost,
                'min_episode_cost', env.min_episode_cost,
                'mean_episode_cost', env._max_episode_steps * np.mean(dataset_dict['costs']))
            print('data_num', dataset_dict['actions'].shape[0])
            # A transition ends an episode if it is terminal or timed out.
            dataset_dict['dones'] = np.logical_or(dataset_dict["terminals"],
                                                dataset_dict["timeouts"]).astype(np.float32)
            del dataset_dict["terminals"]
            del dataset_dict['timeouts']

            if critic_type == "hj":
                dataset_dict['costs'] = np.where(dataset_dict['costs']>0, 1*cost_scale, -1)

        if safe_only:
            # Boolean mask computed once, then applied to every array.
            idx = (dataset_dict["costs"]<=0)
            for key in dataset_dict.keys():
                dataset_dict[key] = dataset_dict[key][idx]

        # Pick the cost-function table; the flags are checked in a fixed
        # priority order and the first one that is set wins.
        cost_func_dict = env2cost_dict
        if wo_reflect:
            cost_func_dict = env2cost_dict_wo_reflect
        elif wo_conserv:
            cost_func_dict = env2cost_dict_wo_conserv
        elif is_deepseek:
            cost_func_dict = env2cost_dict_deepseek
        elif is_gemini:
            cost_func_dict = env2cost_dict_gemini

        if conservative_cost_f:
            # Relabel every transition from its next observation.
            cost_func = cost_func_dict[env_name]
            dataset_dict['costs'] = np.array([cost_func(next_obs) for next_obs in dataset_dict["next_observations"]])
            if critic_type == "hj":
                dataset_dict['costs'] = np.where(dataset_dict['costs']>0, 1*cost_scale, -1)

        self.cost_scale = cost_scale
        self.critic_type = critic_type
        self.clip_to_eps = clip_to_eps
        self.eps = eps
        if clip_to_eps:
            lim = 1 - eps
            dataset_dict["actions"] = np.clip(dataset_dict["actions"], -lim, lim)

        for k, v in dataset_dict.items():
            dataset_dict[k] = v.astype(np.float32)

        dataset_dict["masks"] = 1.0 - dataset_dict['dones']
        del dataset_dict['dones']

        self.separate_buffer = separate_buffer
        if self.separate_buffer:
            # Keep the keys but drop the loaded data: the buffer starts empty.
            for key in dataset_dict.keys():
                dataset_dict[key] = None
        super().__init__(dataset_dict)

    def add_and_norm(self, added_data, max_episode_reward: float, min_episode_reward: float, scaling: float = 1000):
        """Append transitions after rescaling their rewards.

        Rewards are flattened and multiplied by
        ``scaling / (max_episode_reward - min_episode_reward)`` before the
        transitions are appended. Works on an empty buffer as well.

        Args:
            added_data: Dictionary with the keys ``observations``,
                ``actions``, ``next_observations``, ``rewards``, ``dones``
                and ``costs``. It is modified in place: ``rewards`` is
                rescaled, ``masks`` is added and, in ``"hj"`` mode,
                ``costs`` is remapped.
            max_episode_reward: Upper end of the reward normalisation range.
            min_episode_reward: Lower end of the reward normalisation range.
            scaling: Factor applied to the normalised rewards.
        """
        added_data['rewards'] = added_data['rewards'].reshape((-1,))
        added_data['rewards'] = added_data['rewards'] * scaling / (max_episode_reward - min_episode_reward)
        self._append_transitions(added_data)

    def add(self, added_data):
        """Append transitions without touching the reward scale.

        Args:
            added_data: Dictionary with the keys ``observations``,
                ``actions``, ``next_observations``, ``rewards``, ``dones``
                and ``costs``. It is modified in place: ``rewards`` is
                flattened, ``masks`` is added and, in ``"hj"`` mode,
                ``costs`` is remapped.
        """
        added_data['rewards'] = added_data['rewards'].reshape((-1,))
        self._append_transitions(added_data)

    def _append_transitions(self, added_data):
        """Derive masks/costs for ``added_data`` and append it to the buffer."""
        added_data['masks'] = 1.0 - added_data['dones'].reshape((-1,))
        if self.critic_type == "hj":
            added_data['costs'] = np.where(added_data['costs']>0, 1*self.cost_scale, -1).reshape((-1,))

        keys = ('observations', 'actions', 'next_observations', 'rewards', 'masks', 'costs')
        if self.dataset_dict['observations'] is not None:
            for key in keys:
                self.dataset_dict[key] = np.concatenate((self.dataset_dict[key], added_data[key]), axis=0)
        else:
            # Empty buffer (entries are None): the new data becomes the
            # buffer content.
            for key in keys:
                self.dataset_dict[key] = np.array(added_data[key])

        if self.clip_to_eps:
            lim = 1 - self.eps
            self.dataset_dict["actions"] = np.clip(self.dataset_dict["actions"], -lim, lim)

        for k, v in self.dataset_dict.items():
            self.dataset_dict[k] = v.astype(np.float32)

        # Keep the cached length in sync with the grown arrays.
        self.dataset_len = _check_lengths(self.dataset_dict)
        print("New Dataset Length: ", self.dataset_len)
